# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Produce the base dataset. """

import sys
import numpy as np
from mindvision.common.check_param import Validator
from .mnist import mnist_dataset  # pylint: disable=unused-import
from .imagenet2012 import imagenet2012_dataset  # pylint: disable=unused-import

DATASETS_TYPE = ['train', 'eval', 'infer']


class CustomDataset():
    """A toy map-style dataset: 5 random (data, label) pairs.

    Each sample is a pair of a 2-element feature vector and a
    1-element label, both drawn uniformly from [0, 1).
    """

    def __init__(self):
        # 5 rows of 2 features and 5 rows of 1 label, uniform in [0, 1).
        self.data, self.label = np.random.sample((5, 2)), np.random.sample((5, 1))

    def __getitem__(self, index):
        """Return the (feature, label) pair at ``index``."""
        sample = (self.data[index], self.label[index])
        return sample

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.data.shape[0] if hasattr(self.data, "shape") else len(self.data)


def create_dataset(config, states="train"):
    """
    A source dataset for reading and parsing dataset.

    Args:
        config: config parameter; must expose ``train_data``/``valid_data``
            sub-configs plus ``enable_cache``, ``cache_session_id`` and
            ``run_distribute``.
        states (str): one of 'train', 'eval' or 'infer' (case-insensitive).

    Returns:
        The dataset produced by the factory named in
        ``dataset_config.dataset_name``.

    Raises:
        AttributeError: if ``dataset_config.dataset_name`` does not name a
            dataset factory in this module.
    """
    # Normalize once so validation and branching agree: previously a value
    # like "Train" passed validation (lowered) but then fell into the
    # eval branch because the comparison used the original case.
    states = states.lower()
    # states value and type checking against the module-level constant.
    Validator.check_string(states, DATASETS_TYPE)

    if states == "train":
        dataset_config = config.train_data
    else:
        dataset_config = config.valid_data

    # Resolve the factory by name. Using getattr with a default (instead of
    # try/except around the whole call) avoids misreporting an
    # AttributeError raised *inside* the factory as a missing dataset.
    print("start to create datasets.....")
    mod = sys.modules[__name__]
    dataset_function = getattr(mod, dataset_config.dataset_name, None)
    if dataset_function is None:
        raise AttributeError("{} is not in the custom dataset".format(dataset_config.dataset_name))

    dataset = dataset_function(dataset_path=dataset_config.data_dir,
                               states=states,
                               batch_size=dataset_config.batch_size,
                               shuffle=dataset_config.shuffle,
                               repeat_num=dataset_config.repeat_num,
                               num_parallel_workers=dataset_config.num_parallel_workers,
                               enable_cache=config.enable_cache,
                               cache_session_id=config.cache_session_id,
                               distribute=config.run_distribute,
                               config=config)
    print("Create datasets done.....")
    return dataset
